/* -*- c -*- */
/*@targets
 ** $maxopt baseline
 ** sse2 sse42 xop avx2 avx512_skx
 ** vsx2
 ** neon asimd
 ** vx vxe
 **/

#define NPY_NO_DEPRECATED_API NPY_API_VERSION

#include "simd/simd.h"
#include "numpy/npy_math.h"

#include "arraytypes.h"

/* Smaller of two values; note that each argument may be evaluated more than once. */
#define MIN(x, y) ((x) < (y) ? (x) : (y))

#if NPY_SIMD
#if NPY_SIMD > 512 || NPY_SIMD < 0
    #error "the following 8/16-bit argmax kernel isn't applicable for larger SIMD"
    // TODO: add a special loop for larger SIMD widths,
    // i.e. dropping the x4 unroll should be numerically safe up to 2048-bit SIMD width,
    // or maybe expand the indices to 32|64-bit vectors (slower).
#endif
/**begin repeat
 * #sfx = u8, s8, u16, s16#
 * #usfx = u8, u8, u16, u16#
 * #bsfx = b8, b8, b16, b16#
 * #idx_max = NPY_MAX_UINT8*2, NPY_MAX_UINT16*2#
 */
/**begin repeat1
 * #intrin = cmpgt, cmplt#
 * #func = argmax, argmin#
 * #op = >, <#
 */
/*
 * Return the index of the first occurrence of the largest (argmax) or
 * smallest (argmin) element of `ip[0..len)` for 8- and 16-bit integer lanes.
 *
 * `ip`  - input elements; `*ip` is read unconditionally, so the caller is
 *         presumably expected to pass len >= 1 (TODO confirm against callers).
 * `len` - number of elements.
 *
 * The lane type is too narrow to hold a full array index, so the input is
 * consumed in blocks of at most `max_block` elements. Within a block each
 * lane tracks two narrow values: the position inside the current group of
 * four vectors (`acc_indices`, 0..wstep-1) and the number of that group
 * inside the block (`acc_indices_scale`). The full index is rebuilt during
 * the per-block reduction as `ik + scale*wstep + position`.
 */
static inline npy_intp
simd_@func@_@sfx@(npyv_lanetype_@sfx@ *ip, npy_intp len)
{
    // running scalar extremum and its index, carried across blocks
    npyv_lanetype_@sfx@ s_acc = *ip;
    npy_intp ret_idx = 0, i = 0;

    // vstep: lanes per vector; wstep: elements consumed per unrolled iteration
    const int vstep = npyv_nlanes_@sfx@;
    const int wstep = vstep*4;
    // positions 0..wstep-1, split over four vectors (one per unrolled load)
    npyv_lanetype_@usfx@ d_vindices[npyv_nlanes_@sfx@*4];
    for (int vi = 0; vi < wstep; ++vi) {
        d_vindices[vi] = vi;
    }
    const npyv_@usfx@ vindices_0 = npyv_load_@usfx@(d_vindices);
    const npyv_@usfx@ vindices_1 = npyv_load_@usfx@(d_vindices + vstep);
    const npyv_@usfx@ vindices_2 = npyv_load_@usfx@(d_vindices + vstep*2);
    const npyv_@usfx@ vindices_3 = npyv_load_@usfx@(d_vindices + vstep*3);

    // largest block size for which the group counter (i2 below) still fits
    // in the unsigned lane type
    const npy_intp max_block = @idx_max@*wstep & -wstep;
    // part of the input covered by the unrolled loop
    // (len rounded down to a multiple of wstep, assuming wstep is a power of two)
    npy_intp len0 = len & -wstep;
    while (i < len0) {
        // every block starts from the best value seen so far, with zeroed
        // narrow indices
        npyv_@sfx@ acc = npyv_setall_@sfx@(s_acc);
        npyv_@usfx@ acc_indices = npyv_zero_@usfx@();
        npyv_@usfx@ acc_indices_scale = npyv_zero_@usfx@();

        // n: end of this block; ik: start of this block; i2: group counter
        npy_intp n = i + MIN(len0 - i, max_block);
        npy_intp ik = i, i2 = 0;
        for (; i < n; i += wstep, ++i2) {
            npyv_@usfx@ vi = npyv_setall_@usfx@((npyv_lanetype_@usfx@)i2);
            npyv_@sfx@ a = npyv_load_@sfx@(ip + i);
            npyv_@sfx@ b = npyv_load_@sfx@(ip + i + vstep);
            npyv_@sfx@ c = npyv_load_@sfx@(ip + i + vstep*2);
            npyv_@sfx@ d = npyv_load_@sfx@(ip + i + vstep*3);

            // The later vector is the first operand of the strict comparison:
            // on a tie the mask is false and the select keeps the earlier
            // vector, so the lowest index wins in case of matched values.
            npyv_@bsfx@ m_ba = npyv_@intrin@_@sfx@(b, a);
            npyv_@bsfx@ m_dc = npyv_@intrin@_@sfx@(d, c);
            npyv_@sfx@  x_ba = npyv_select_@sfx@(m_ba, b, a);
            npyv_@sfx@  x_dc = npyv_select_@sfx@(m_dc, d, c);
            npyv_@bsfx@ m_dcba = npyv_@intrin@_@sfx@(x_dc, x_ba);
            npyv_@sfx@  x_dcba = npyv_select_@sfx@(m_dcba, x_dc, x_ba);

            // route the positions through the same masks as the values
            npyv_@usfx@ idx_ba = npyv_select_@usfx@(m_ba, vindices_1, vindices_0);
            npyv_@usfx@ idx_dc = npyv_select_@usfx@(m_dc, vindices_3, vindices_2);
            npyv_@usfx@ idx_dcba = npyv_select_@usfx@(m_dcba, idx_dc, idx_ba);
            // update the accumulators only where this group strictly improves
            npyv_@bsfx@ m_acc = npyv_@intrin@_@sfx@(x_dcba, acc);
            acc = npyv_select_@sfx@(m_acc, x_dcba, acc);
            acc_indices = npyv_select_@usfx@(m_acc, idx_dcba, acc_indices);
            acc_indices_scale = npyv_select_@usfx@(m_acc, vi, acc_indices_scale);
        }
        // reduce: spill the per-lane results of this block and fold them into
        // the scalar state
        npyv_lanetype_@sfx@ dacc[npyv_nlanes_@sfx@];
        npyv_lanetype_@usfx@ dacc_i[npyv_nlanes_@sfx@];
        npyv_lanetype_@usfx@ dacc_s[npyv_nlanes_@sfx@];
        npyv_store_@sfx@(dacc, acc);
        npyv_store_@usfx@(dacc_i, acc_indices);
        npyv_store_@usfx@(dacc_s, acc_indices_scale);

        // first pass: find a strictly better value than the running extremum
        for (int vi = 0; vi < vstep; ++vi) {
            if (dacc[vi] @op@ s_acc) {
                s_acc = dacc[vi];
                ret_idx = ik + (npy_intp)dacc_s[vi]*wstep + dacc_i[vi];
            }
        }
        // second pass: get the lowest index in case of matched values
        for (int vi = 0; vi < vstep; ++vi) {
            npy_intp idx = ik + (npy_intp)dacc_s[vi]*wstep + dacc_i[vi];
            if (s_acc == dacc[vi] && ret_idx > idx) {
                ret_idx = idx;
            }
        }
    }
    // scalar tail: the elements left over after the unrolled loop
    for (; i < len; ++i) {
        npyv_lanetype_@sfx@ a = ip[i];
        if (a @op@ s_acc) {
            s_acc = a;
            ret_idx = i;
        }
    }
    return ret_idx;
}
/**end repeat1**/
/**end repeat**/
#endif

/**begin repeat
 * #sfx = u32, s32, u64, s64, f32, f64#
 * #usfx = u32, u32, u64, u64, u32, u64#
 * #bsfx = b32, b32, b64, b64, b32, b64#
 * #is_fp = 0*4, 1*2#
 * #is_idx32 = 1*2, 0*2, 1, 0#
 * #chk_simd = NPY_SIMD*4, NPY_SIMD_F32, NPY_SIMD_F64#
 */
#if @chk_simd@
/**begin repeat1
 * #intrin = cmpgt, cmplt#
 * #func = argmax, argmin#
 * #op = >, <#
 * #iop = <, >#
 */
/*
 * Return the index of the first occurrence of the largest (argmax) or
 * smallest (argmin) element of `ip[0..len)` for 32- and 64-bit lanes
 * (integers and, when enabled, floating point).
 *
 * `ip`  - input elements; `*ip` is read unconditionally, so the caller is
 *         presumably expected to pass len >= 1 (TODO confirm against callers).
 * `len` - number of elements.
 *
 * Absolute indices are kept in unsigned lanes of the same width as the data,
 * so one accumulator pair (`acc`, `acc_indices`) is carried across the whole
 * vectorized range and reduced once at the end. With 32-bit index lanes the
 * vectorized range is capped at NPY_MAX_UINT32 elements; anything beyond that
 * is handled by the scalar loop.
 *
 * For floating-point instantiations, the index of the first NaN found is
 * returned as soon as one is seen.
 */
static inline npy_intp
simd_@func@_@sfx@(npyv_lanetype_@sfx@ *ip, npy_intp len)
{
    npyv_lanetype_@sfx@ s_acc = *ip;
    npy_intp ret_idx = 0, i = 0;
    // vstep: lanes per vector; wstep: elements consumed per unrolled iteration
    const int vstep = npyv_nlanes_@sfx@;
    const int wstep = vstep*4;
    // loop by a scalar will perform better for small arrays
    if (len < wstep) {
        goto scalar_loop;
    }
    npy_intp len0 = len;
    // guard against wraparound in the vector index addition for 32-bit
    // indices, in case the array is larger than 16gb
#if @is_idx32@
    if (len0 > NPY_MAX_UINT32) {
        len0 = NPY_MAX_UINT32;
    }
#endif
    // create index for vector indices: positions 0..wstep-1 split over four
    // vectors, one per unrolled load
    npyv_lanetype_@usfx@ d_vindices[npyv_nlanes_@sfx@*4];
    for (int vi = 0; vi < wstep; ++vi) {
        d_vindices[vi] = vi;
    }
    const npyv_@usfx@ vindices_0 = npyv_load_@usfx@(d_vindices);
    const npyv_@usfx@ vindices_1 = npyv_load_@usfx@(d_vindices + vstep);
    const npyv_@usfx@ vindices_2 = npyv_load_@usfx@(d_vindices + vstep*2);
    const npyv_@usfx@ vindices_3 = npyv_load_@usfx@(d_vindices + vstep*3);
    // initialize the vector accumulators for the best values and their
    // indexes: every lane starts as (ip[0], index 0)
    npyv_@usfx@ acc_indices = npyv_zero_@usfx@();
    npyv_@sfx@ acc = npyv_setall_@sfx@(s_acc);
    // main loop, unrolled by four vectors
    for (npy_intp n = len0 & -wstep; i < n; i += wstep) {
        // base index of this iteration, broadcast to all lanes
        npyv_@usfx@ vi = npyv_setall_@usfx@((npyv_lanetype_@usfx@)i);
        npyv_@sfx@ a = npyv_load_@sfx@(ip + i);
        npyv_@sfx@ b = npyv_load_@sfx@(ip + i + vstep);
        npyv_@sfx@ c = npyv_load_@sfx@(ip + i + vstep*2);
        npyv_@sfx@ d = npyv_load_@sfx@(ip + i + vstep*3);

        // The later vector is the first operand of the strict comparison:
        // on a tie the mask is false and the select keeps the earlier
        // vector, so the lowest index wins in case of matched values.
        npyv_@bsfx@ m_ba = npyv_@intrin@_@sfx@(b, a);
        npyv_@bsfx@ m_dc = npyv_@intrin@_@sfx@(d, c);
        npyv_@sfx@  x_ba = npyv_select_@sfx@(m_ba, b, a);
        npyv_@sfx@  x_dc = npyv_select_@sfx@(m_dc, d, c);
        npyv_@bsfx@ m_dcba = npyv_@intrin@_@sfx@(x_dc, x_ba);
        npyv_@sfx@  x_dcba = npyv_select_@sfx@(m_dcba, x_dc, x_ba);

        // route the positions through the same masks as the values
        npyv_@usfx@ idx_ba = npyv_select_@usfx@(m_ba, vindices_1, vindices_0);
        npyv_@usfx@ idx_dc = npyv_select_@usfx@(m_dc, vindices_3, vindices_2);
        npyv_@usfx@ idx_dcba = npyv_select_@usfx@(m_dcba, idx_dc, idx_ba);
        // update the accumulators only where this group strictly improves;
        // the stored index is absolute (base + position)
        npyv_@bsfx@ m_acc = npyv_@intrin@_@sfx@(x_dcba, acc);
        acc = npyv_select_@sfx@(m_acc, x_dcba, acc);
        acc_indices = npyv_select_@usfx@(m_acc, npyv_add_@usfx@(vi, idx_dcba), acc_indices);

    #if @is_fp@
        // NaN check: build a mask of the lanes that are not NaN in all four
        // vectors and compare its bits against the all-lanes-set pattern
        npyv_@bsfx@ nnan_a = npyv_notnan_@sfx@(a);
        npyv_@bsfx@ nnan_b = npyv_notnan_@sfx@(b);
        npyv_@bsfx@ nnan_c = npyv_notnan_@sfx@(c);
        npyv_@bsfx@ nnan_d = npyv_notnan_@sfx@(d);
        npyv_@bsfx@ nnan_ab = npyv_and_@bsfx@(nnan_a, nnan_b);
        npyv_@bsfx@ nnan_cd = npyv_and_@bsfx@(nnan_c, nnan_d);
        npy_uint64 nnan = npyv_tobits_@bsfx@(npyv_and_@bsfx@(nnan_ab, nnan_cd));
        if ((unsigned long long int)nnan != ((1ULL << vstep) - 1)) {
            // at least one NaN was loaded: scan the four vectors in memory
            // order and return the index of the first NaN
            npy_uint64 nnan_4[4];
            nnan_4[0] = npyv_tobits_@bsfx@(nnan_a);
            nnan_4[1] = npyv_tobits_@bsfx@(nnan_b);
            nnan_4[2] = npyv_tobits_@bsfx@(nnan_c);
            nnan_4[3] = npyv_tobits_@bsfx@(nnan_d);
            for (int ni = 0; ni < 4; ++ni) {
                // note: this int `vi` shadows the index vector `vi` above
                for (int vi = 0; vi < vstep; ++vi) {
                    if (!((nnan_4[ni] >> vi) & 1)) {
                        return i + ni*vstep + vi;
                    }
                }
            }
        }
    #endif
    }
    // secondary loop: one vector at a time for what the unrolled loop left
    for (npy_intp n = len0 & -vstep; i < n; i += vstep) {
        npyv_@usfx@ vi = npyv_setall_@usfx@((npyv_lanetype_@usfx@)i);
        npyv_@sfx@ a = npyv_load_@sfx@(ip + i);
        npyv_@bsfx@ m_acc = npyv_@intrin@_@sfx@(a, acc);
        acc = npyv_select_@sfx@(m_acc, a, acc);
        acc_indices = npyv_select_@usfx@(m_acc, npyv_add_@usfx@(vi, vindices_0), acc_indices);
    #if @is_fp@
        // same NaN early exit as in the unrolled loop, for a single vector
        npyv_@bsfx@ nnan_a = npyv_notnan_@sfx@(a);
        npy_uint64 nnan = npyv_tobits_@bsfx@(nnan_a);
        if ((unsigned long long int)nnan != ((1ULL << vstep) - 1)) {
            for (int vi = 0; vi < vstep; ++vi) {
                if (!((nnan >> vi) & 1)) {
                    return i + vi;
                }
            }
        }
    #endif
    }

    // reduce: spill the per-lane results and fold them into the scalar state
    npyv_lanetype_@sfx@ dacc[npyv_nlanes_@sfx@];
    npyv_lanetype_@usfx@ dacc_i[npyv_nlanes_@sfx@];
    npyv_store_@usfx@(dacc_i, acc_indices);
    npyv_store_@sfx@(dacc, acc);

    // first pass: pick the best value over all lanes
    s_acc = dacc[0];
    ret_idx = dacc_i[0];
    for (int vi = 1; vi < vstep; ++vi) {
        if (dacc[vi] @op@ s_acc) {
            s_acc = dacc[vi];
            ret_idx = (npy_intp)dacc_i[vi];
        }
    }
    // second pass: get the lowest index in case of matched values
    for (int vi = 0; vi < vstep; ++vi) {
        if (s_acc == dacc[vi] && ret_idx > (npy_intp)dacc_i[vi]) {
            ret_idx = dacc_i[vi];
        }
    }
    // scalar tail; also the whole computation for small arrays (i is still 0
    // when arriving here through the goto above)
scalar_loop:
    for (; i < len; ++i) {
        npyv_lanetype_@sfx@ a = ip[i];
    #if @is_fp@
        if (!(a @iop@= s_acc)) {  // negated, for correct nan handling
    #else
        if (a @op@ s_acc) {
    #endif
            s_acc = a;
            ret_idx = i;
        #if @is_fp@
            if (npy_isnan(s_acc)) {
                // nan encountered, it's maximal|minimal
                return ret_idx;
            }
        #endif
        }
    }
    return ret_idx;
}
/**end repeat1**/
#endif // chk_simd
/**end repeat**/

/**begin repeat
 * #TYPE = UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *         BYTE, SHORT, INT, LONG, LONGLONG,
 *         FLOAT, DOUBLE, LONGDOUBLE#
 *
 * #BTYPE = BYTE, SHORT, INT, LONG, LONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG,
 *          FLOAT, DOUBLE, LONGDOUBLE#
 * #type = npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong,
 *         npy_float, npy_double, npy_longdouble#
 *
 * #is_fp = 0*10, 1*3#
 * #is_unsigned = 1*5, 0*5, 0*3#
 */
// Pick the SIMD kernel suffix for the current type from its bit width and kind.
// TO_SIMD_SFX(X) pastes that suffix onto X; it is left undefined when no SIMD
// kernel applies, in which case the functions below use their scalar loop.
#undef TO_SIMD_SFX
#if 0
/**begin repeat1
 * #len = 8, 16, 32, 64#
 */
#elif NPY_SIMD && NPY_BITSOF_@BTYPE@ == @len@
    #if @is_fp@
        #define TO_SIMD_SFX(X) X##_f@len@
        // drop the macro again when the target has no SIMD support for
        // this floating-point width
        #if NPY_BITSOF_@BTYPE@ == 64 && !NPY_SIMD_F64
            #undef TO_SIMD_SFX
        #endif
        #if NPY_BITSOF_@BTYPE@ == 32 && !NPY_SIMD_F32
            #undef TO_SIMD_SFX
        #endif
    #elif @is_unsigned@
        #define TO_SIMD_SFX(X) X##_u@len@
    #else
        #define TO_SIMD_SFX(X) X##_s@len@
    #endif
/**end repeat1**/
#endif

/**begin repeat1
 * #func = argmax, argmin#
 * #op = >, <#
 * #iop = <, >#
 */
/*
 * Store in `*mindx` the index of the first occurrence of the largest (argmax)
 * or smallest (argmin) element of `ip[0..n)`; always returns 0.
 *
 * For floating-point types a NaN counts as the extremum, so the index of the
 * first NaN is stored. The SIMD kernel is used when TO_SIMD_SFX is defined
 * for the current type, otherwise a plain scalar scan runs.
 */
NPY_NO_EXPORT int NPY_CPU_DISPATCH_CURFX(@TYPE@_@func@)
(@type@ *ip, npy_intp n, npy_intp *mindx, PyArrayObject *NPY_UNUSED(aip))
{
#if @is_fp@
    // a leading NaN settles the answer right away: it's maximal|minimal
    if (npy_isnan(ip[0])) {
        *mindx = 0;
        return 0;
    }
#endif
#ifdef TO_SIMD_SFX
    *mindx = TO_SIMD_SFX(simd_@func@)((TO_SIMD_SFX(npyv_lanetype)*)ip, n);
    npyv_cleanup();
#else
    @type@ best = ip[0];
    *mindx = 0;

    for (npy_intp pos = 1; pos < n; ++pos) {
        const @type@ cur = ip[pos];
    #if @is_fp@
        // negated comparison, so that a NaN candidate is accepted as well
        const int improves = !(cur @iop@= best);
    #else
        const int improves = (cur @op@ best);
    #endif
        if (!improves) {
            continue;
        }
        best = cur;
        *mindx = pos;
    #if @is_fp@
        if (npy_isnan(best)) {
            // nothing can beat a NaN, stop scanning
            break;
        }
    #endif
    }
#endif // TO_SIMD_SFX
    return 0;
}
/**end repeat1**/
/**end repeat**/

/*
 * Store in `*mindx` the index of the first nonzero element of `ip[0..len)`,
 * or 0 when every element is zero; always returns 0.
 *
 * With SIMD enabled, groups of four vectors that contain only zeros are
 * skipped; the scalar scan then resumes at the first group that holds a
 * nonzero byte (or at the unvectorized tail).
 */
NPY_NO_EXPORT int NPY_CPU_DISPATCH_CURFX(BOOL_argmax)
(npy_bool *ip, npy_intp len, npy_intp *mindx, PyArrayObject *NPY_UNUSED(aip))

{
    npy_intp pos = 0;
#if NPY_SIMD
    const npyv_u8 v_zero = npyv_zero_u8();
    const int lanes = npyv_nlanes_u8;
    const int group = lanes * 4;
    const npy_intp simd_end = len & -group;
    while (pos < simd_end) {
        npyv_u8 v0 = npyv_load_u8(ip + pos);
        npyv_u8 v1 = npyv_load_u8(ip + pos + lanes);
        npyv_u8 v2 = npyv_load_u8(ip + pos + lanes*2);
        npyv_u8 v3 = npyv_load_u8(ip + pos + lanes*3);
        // per-lane "is zero" masks, combined so that a lane stays set only
        // if it is zero in all four vectors
        npyv_b8 z0 = npyv_cmpeq_u8(v0, v_zero);
        npyv_b8 z1 = npyv_cmpeq_u8(v1, v_zero);
        npyv_b8 z2 = npyv_cmpeq_u8(v2, v_zero);
        npyv_b8 z3 = npyv_cmpeq_u8(v3, v_zero);
        npyv_b8 z01 = npyv_and_b8(z0, z1);
        npyv_b8 z23 = npyv_and_b8(z2, z3);
        npy_uint64 zero_bits = npyv_tobits_b8(npyv_and_b8(z01, z23));
        // 512-bit vectors have 64 byte lanes, so the all-set pattern is
        // spelled out instead of being built with a shift by the lane count
    #if NPY_SIMD == 512
        const int found_nonzero = (zero_bits != NPY_MAX_UINT64);
    #else
        const int found_nonzero = ((npy_int64)zero_bits != ((1LL << lanes) - 1));
    #endif
        if (found_nonzero) {
            break;
        }
        pos += group;
    }
    npyv_cleanup();
#endif // NPY_SIMD
    // scalar scan from wherever the vector loop stopped
    while (pos < len) {
        if (ip[pos]) {
            *mindx = pos;
            return 0;
        }
        ++pos;
    }
    *mindx = 0;
    return 0;
}
